# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import torch
from torch.nn import init

from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn
from flash_attn.ops.layer_norm import DropoutAddLayerNormParallelResidualFn


def rms_norm(x, weight, epsilon):
    """Run ``DropoutAddLayerNormFn`` on ``x`` with only a weight and an epsilon.

    Every optional slot of the fused function is filled with a neutral value:
    ``None`` for residual, bias, rowscale and layerscale, ``0.0`` for the
    dropout probability, and ``False`` for both ``residual_in_fp32`` and
    ``prenorm``. Slot names follow the positional order that
    ``dropout_add_rms_norm`` in this file uses for the same function.

    Returns whatever ``DropoutAddLayerNormFn.apply`` returns for these
    arguments.
    """
    fused_args = (
        x,        # x0
        None,     # residual
        weight,
        None,     # bias
        None,     # rowscale
        None,     # layerscale
        0.0,      # dropout_p
        epsilon,
        False,    # residual_in_fp32
        False,    # prenorm
        # NOTE(review): presumably the "is RMS norm" switch of the fused
        # function -- confirm against DropoutAddLayerNormFn.
        True,
    )
    return DropoutAddLayerNormFn.apply(*fused_args)


def dropout_add_rms_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """Forward the arguments positionally to ``DropoutAddLayerNormFn.apply``.

    The eleventh positional argument of the fused function is pinned to
    ``True`` (NOTE(review): presumably its RMS-norm switch -- confirm against
    DropoutAddLayerNormFn).

    ``residual_in_fp32`` matters only when ``residual`` is None; when a
    residual is supplied, the residual dtype is ``residual.dtype``.
    """
    # NOTE(review): presumably the RMS-norm switch -- see docstring.
    rms_flag = True
    # Arguments are unpacked positionally, in the order the fused function
    # receives them.
    fused_args = (
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        rms_flag,
        return_dropout_mask,
    )
    return DropoutAddLayerNormFn.apply(*fused_args)


def dropout_add_rms_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """Forward the arguments positionally to ``DropoutAddLayerNormSubsetFn.apply``.

    The second-to-last positional argument of the fused function is pinned to
    ``True`` (NOTE(review): presumably its RMS-norm switch -- confirm against
    DropoutAddLayerNormSubsetFn).

    ``residual_in_fp32`` matters only when ``residual`` is None; when a
    residual is supplied, the residual dtype is ``residual.dtype``.
    """
    # NOTE(review): presumably the RMS-norm switch -- see docstring.
    rms_flag = True
    # Note the fused function's order differs from this wrapper's signature:
    # layerscale and the subset tensors come before dropout_p / epsilon.
    fused_args = (
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        rms_flag,
        return_dropout_mask,
    )
    return DropoutAddLayerNormSubsetFn.apply(*fused_args)


def dropout_add_rms_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """Forward the arguments positionally to
    ``DropoutAddLayerNormParallelResidualFn.apply``.

    The second-to-last positional argument of the fused function is pinned to
    ``True`` (NOTE(review): presumably its RMS-norm switch -- confirm against
    DropoutAddLayerNormParallelResidualFn).

    ``residual_in_fp32`` matters only when ``residual`` is None; when a
    residual is supplied, the residual dtype is ``residual.dtype``.
    """
    # NOTE(review): presumably the RMS-norm switch -- see docstring.
    rms_flag = True
    # The fused function takes residual_in_fp32 before prenorm, the reverse
    # of this wrapper's keyword order.
    fused_args = (
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        rms_flag,
        return_dropout_mask,
    )
    return DropoutAddLayerNormParallelResidualFn.apply(*fused_args)


class RMSNorm(torch.nn.Module):
    """Module wrapper around ``rms_norm`` with a learnable weight and no bias.

    Holds a ``weight`` parameter of size ``hidden_size`` (initialised to ones)
    and registers ``bias`` as an explicit ``None`` parameter.
    """

    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        """Create the weight on ``device`` / ``dtype`` and initialise it."""
        super().__init__()
        self.eps = eps
        # Allocate uninitialised storage; reset_parameters() fills it below.
        uninitialised = torch.empty(hidden_size, device=device, dtype=dtype)
        self.weight = torch.nn.Parameter(uninitialised)
        # Registered as None so the attribute exists without adding a tensor.
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Fill ``weight`` with ones in place."""
        init.ones_(self.weight)

    def forward(self, x):
        """Return ``rms_norm`` of ``x`` using this module's weight and eps."""
        return rms_norm(x, self.weight, self.eps)


class DropoutAddRMSNorm(torch.nn.Module):
    """Module wrapper around ``dropout_add_rms_norm``.

    Stores the dropout probability ``p``, ``eps``, and the ``prenorm`` /
    ``residual_in_fp32`` flags, together with a learnable ``weight`` of size
    ``hidden_size`` (initialised to ones). ``bias`` is registered as an
    explicit ``None`` parameter.
    """

    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        """Record the configuration and create the weight on ``device`` / ``dtype``."""
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        # Allocate uninitialised storage; reset_parameters() fills it below.
        uninitialised = torch.empty(hidden_size, device=device, dtype=dtype)
        self.weight = torch.nn.Parameter(uninitialised)
        # Registered as None so the attribute exists without adding a tensor.
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Fill ``weight`` with ones in place."""
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        """Call ``dropout_add_rms_norm`` with this module's weight and settings.

        The stored probability ``p`` is used only while ``self.training`` is
        true; otherwise 0.0 is passed as the dropout probability. The bias
        argument is always ``None``.
        """
        active_p = self.p if self.training else 0.0
        return dropout_add_rms_norm(
            x0,
            residual,
            self.weight,
            None,
            active_p,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
